Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#432] Add Groq Provider - chat completions #609

Merged
merged 1 commit into from
Jan 3, 2025

Conversation

aidando73
Copy link
Contributor

@aidando73 aidando73 commented Dec 11, 2024

What does this PR do?

Contributes towards issue (#432)

  • Groq text chat completions
  • Streaming
  • All the sampling params that Groq supports

A lot of inspiration taken from @mattf's good work at #355

What this PR does not do

  • Tool calls (Future PR)
  • Adding llama-guard model
  • See if we can add embeddings

PR Train

Test Plan

Environment
export GROQ_API_KEY=<api_key>

wget https://raw.githubusercontent.com/aidando73/llama-stack/240e6e2a9c20450ffdcfbabd800a6c0291f19288/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/92c9b5297f9eda6a6e901e1adbd894e169dbb278/run.yaml

# Build and run environment
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
  --port 5001
Manual tests

Using this jupyter notebook to test manually: https://github.com/aidando73/llama-stack/blob/2140976d76ee7ef46025c862b26ee87585381d2a/hello.ipynb

Use this code to test passing in the api key from provider_data

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5001",
)

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[
        {"role": "user", "content": "Hello, world client!"},
    ],
    # Test passing in groq_api_key from the client
    # Need to comment out the groq_api_key in the run.yaml file
    x_llama_stack_provider_data='{"groq_api_key": "<api-key>"}',
    # stream=True,
)
response
Integration

pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k groq

(run in same environment)

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-groq] PASSED                 [  6%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-groq] SKIPPED (Other inf...) [ 12%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_3b-groq] SKIPPED [ 18%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-groq] PASSED [ 25%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-groq] SKIPPED (Ot...) [ 31%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-groq] PASSED  [ 37%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-groq] SKIPPED [ 43%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-groq] SKIPPED [ 50%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-groq] PASSED                 [ 56%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-groq] SKIPPED (Other inf...) [ 62%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_8b-groq] SKIPPED [ 68%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-groq] PASSED [ 75%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-groq] SKIPPED (Ot...) [ 81%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-groq] PASSED  [ 87%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-groq] SKIPPED [ 93%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-groq] SKIPPED [100%]

======================================= 6 passed, 10 skipped, 160 deselected, 7 warnings in 2.05s ========================================
Unit tests

pytest llama_stack/providers/tests/inference/groq/ -v

llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_sets_model PASSED            [  5%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_user_message PASSED [ 10%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_system_message PASSED [ 15%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_completion_message PASSED [ 20%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_logprobs PASSED [ 25%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_response_format PASSED [ 30%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_repetition_penalty PASSED [ 35%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_stream PASSED       [ 40%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_n_is_1 PASSED                [ 45%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_if_max_tokens_is_0_then_it_is_not_included PASSED [ 50%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_max_tokens_if_set PASSED [ 55%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_temperature PASSED  [ 60%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_top_p PASSED        [ 65%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_returns_response PASSED [ 70%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_stop_to_end_of_message PASSED [ 75%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_length_to_end_of_message PASSED [ 80%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertStreamChatCompletionResponse::test_returns_stream PASSED [ 85%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_raises_runtime_error_if_config_is_not_groq_config PASSED [ 90%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_returns_groq_adapter PASSED                            [ 95%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqConfig::test_api_key_defaults_to_env_var PASSED                   [100%]

==================================================== 20 passed, 11 warnings in 0.08s =====================================================

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Ran pre-commit to handle lint / formatting issues.
  • Read the contributor guideline,
    Pull Request section?
  • Updated relevant documentation
  • Wrote necessary unit or integration tests.

@aidando73 aidando73 marked this pull request as draft December 11, 2024 23:39
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 11, 2024
elif finish_reason == "length":
return StopReason.end_of_message
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users won't be able to hit this error yet since they can't pass tools as a parameter

"remote::groq",
):
pytest.skip(provider.__provider_spec__.provider_type + " doesn't support tool calling yet")

Copy link
Contributor Author

@aidando73 aidando73 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per comment above: https://github.com/meta-llama/llama-stack/pull/609/files#r1881443869

Will remove after I implement

warnings.warn("repetition_penalty is not supported")

if request.tools:
warnings.warn("tools are not supported yet")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m planning to handle tool calls in a separate PR since there are edge cases I want to cover properly.

But lmk if you want me to include it within this PR

@aidando73 aidando73 changed the title Add Groq Inference Provider Add Groq Provider - chat completions Dec 12, 2024
@aidando73 aidando73 force-pushed the aidand-432-groq_2 branch 2 times, most recently from ec8b47b to 3f2498e Compare December 12, 2024 10:28
@aidando73 aidando73 marked this pull request as ready for review December 12, 2024 10:53
Copy link
Contributor

@mattf mattf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking good!

@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - the groq library will read GROQ_API_KEY env (https://github.com/groq/groq-python/blob/main/src/groq/_client.py#L86), consider adding a comment here so people in the LS codebase know this expectation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not rely on environment variables for code that we expect to run in llama-stack server. We would want to take in the api key as a config variable in run.yaml when we spin up the server

Copy link
Contributor Author

@aidando73 aidando73 Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would want to take in the api key as a config variable in run.yaml when we spin up the server

@raghotham, I believe that's the behaviour at the moment. This is how fireworks and together define their configs:

api_key: Optional[str] = Field(
default=None,
description="The Fireworks.ai API Key",
)

api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key",
)

And it's in the run.yaml that you define the environment variable:

config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}

wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes @aidando73 this is correct. Note that both Together and Fireworks also support grabbing the api key from headers via the NeedsProviderData mixin. You can add that if you feel like it.

Copy link
Contributor Author

@aidando73 aidando73 Dec 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done - added the mixin

Added some client code in the test plan as well to test

]:

if model_id == "llama-3.2-3b-preview":
warnings.warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very user friendly +1

logprobs=None,
frequency_penalty=None,
stream=request.stream,
# Groq only supports n=1 at the time of writing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the LS structures only support responses w/ 1 choice, so the fact that groq only supports n=1 is moot. you should even skip passing the param around.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

raise ValueError(f"Invalid finish reason: {finish_reason}")


async def convert_chat_completion_response_stream(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in some other PR, we should merge this into the general util module.

Copy link
Contributor Author

@aidando73 aidando73 Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this function is too coupled to Groq types to be used as general util function? E.g., this one takes in a ChatCompletionChunk from groq.types.chat.chat_completion_chunk

@aidando73 aidando73 force-pushed the aidand-432-groq_2 branch 3 times, most recently from 2f1522a to 3587f08 Compare December 14, 2024 00:42

if request.logprobs:
# Groq doesn't support logprobs at the time of writing
warnings.warn("logprobs are not supported yet")
Copy link
Contributor Author

@aidando73 aidando73 Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if request.response_format:
# Groq's JSON mode is beta at the time of writing
warnings.warn("response_format is not supported yet")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aidando73 aidando73 changed the title Add Groq Provider - chat completions [#432] Add Groq Provider - chat completions Dec 15, 2024


async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the correct comment should be import dynamically so the import is used only when it is needed :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"llama-3.3-70b-versatile",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you know what does this suffix indicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find anything online. @ricklamers @philass - could you provide any additional context here?


yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_types),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems unnecessary to refactor the event type generator separately. just maintain an index here or even more simply just repeat a small amount of code to send back a start chunk.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@ashwinb ashwinb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome, thank you very much!

@ashwinb
Copy link
Contributor

ashwinb commented Dec 20, 2024

(will merge after a couple small comments are addressed)

@aidando73
Copy link
Contributor Author

Ok @ashwinb I've addressed your comments

@cheesecake100201
Copy link
Contributor

Can this PR be merged soon ?
@ashwinb @aidando73

@ashwinb
Copy link
Contributor

ashwinb commented Jan 3, 2025

Thanks for the reminder @cheesecake100201! Sorry was out on vacation and missed this one after coming back.

@ashwinb ashwinb merged commit e1f42eb into meta-llama:main Jan 3, 2025
2 checks passed
@aidando73
Copy link
Contributor Author

@ashwinb I missed a variable name while refactoring could you PTAL #719 👈 ?

ashwinb pushed a commit that referenced this pull request Jan 3, 2025
# What does this PR do?

Contributes towards: #432

RE: #609

I missed this one while refactoring. Fixes:

```python
Traceback (most recent call last):
  File "/Users/aidand/dev/llama-stack/llama_stack/distribution/server/server.py", line 191, in endpoint
    return await maybe_await(value)
  File "/Users/aidand/dev/llama-stack/llama_stack/distribution/server/server.py", line 155, in maybe_await
    return await value
  File "/Users/aidand/dev/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 101, in async_wrapper
    result = await method(self, *args, **kwargs)
  File "/Users/aidand/dev/llama-stack/llama_stack/distribution/routers/routers.py", line 156, in chat_completion
    return await provider.chat_completion(**params)
  File "/Users/aidand/dev/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 101, in async_wrapper
    result = await method(self, *args, **kwargs)
  File "/Users/aidand/dev/llama-stack/llama_stack/providers/remote/inference/groq/groq.py", line 127, in chat_completion
    response = self._get_client().chat.completions.create(**request)
  File "/Users/aidand/dev/llama-stack/llama_stack/providers/remote/inference/groq/groq.py", line 143, in _get_client
    return Groq(api_key=self.config.api_key)
AttributeError: 'GroqInferenceAdapter' object has no attribute 'config'. Did you mean: '_config'?
```


## Test Plan

Environment:

```shell
export GROQ_API_KEY=<api-key>

# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml

# Create environment if not already
conda create --prefix ./envs python=3.10
conda activate ./envs

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq
```
<details>
<summary>Manual</summary>

```bash
llama stack run ./run.yaml --port 5001
```

Via this Jupyter notebook:
https://github.com/aidando73/llama-stack/blob/9165502582cd7cb178bc1dcf89955b45768ab6c1/hello.ipynb
</details>


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants